{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "import nemo\n",
    "from nemo import logging\n",
    "from nemo.utils.lr_policies import WarmupAnnealing\n",
    "\n",
    "import nemo.collections.nlp as nemo_nlp\n",
    "from nemo.collections.nlp.data import NemoBertTokenizer\n",
    "from nemo.collections.nlp.nm.trainables import PunctCapitTokenClassifier\n",
    "from nemo.backends.pytorch.common.losses import CrossEntropyLossNM, LossAggregatorNM\n",
    "from nemo.collections.nlp.callbacks.punctuation_capitalization_callback import eval_iter_callback, eval_epochs_done_callback\n",
    "from nemo.collections.nlp.data.datasets.datasets_utils import calc_class_weights\n",
    "\n",
    "# Placeholder paths -- replace both before running the notebook.\n",
    "DATA_DIR = \"PATH_TO_WHERE_THE_DATA_IS\"\n",
    "WORK_DIR = \"PATH_TO_WHERE_TO_STORE_CHECKPOINTS_AND_LOGS\"\n",
    "\n",
    "# Pre-trained language model to fine-tune. The code below loads it with\n",
    "# nemo_nlp.nm.trainables.get_pretrained_lm_model(); call\n",
    "# nemo_nlp.nm.trainables.get_pretrained_lm_models_list() to see the available names.\n",
    "PRETRAINED_BERT_MODEL = \"bert-base-uncased\"\n",
    "\n",
    "# Data\n",
    "NUM_SAMPLES = 100_000  # forwarded to get_tatoeba_data.py\n",
    "MAX_SEQ_LENGTH = 64\n",
    "BATCH_SIZE = 128\n",
    "\n",
    "# Model\n",
    "CLASSIFICATION_DROPOUT = 0.1\n",
    "PUNCT_NUM_FC_LAYERS = 3\n",
    "\n",
    "# Optimization\n",
    "OPTIMIZER = \"adam\"\n",
    "LEARNING_RATE = 2e-5\n",
    "LR_WARMUP_PROPORTION = 0.1\n",
    "NUM_EPOCHS = 10\n",
    "BATCHES_PER_STEP = 1\n",
    "\n",
    "# Number of steps between loss printouts and between checkpoint saves\n",
    "STEP_FREQ = 200"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download and preprocess the data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook we're going to use a subset of English examples from the [Tatoeba collection of sentences](https://tatoeba.org/eng). To improve the performance of the model, set NUM_SAMPLES=-1 and consider including other datasets. Use [NeMo/examples/nlp/token_classification/get_tatoeba_data.py](https://github.com/NVIDIA/NeMo/blob/master/examples/nlp/token_classification/get_tatoeba_data.py) to download and preprocess the Tatoeba data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This should take about a minute since the data is already downloaded in the previous step\n",
    "\n",
    "! python get_tatoeba_data.py --data_dir $DATA_DIR --num_sample $NUM_SAMPLES"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After the previous step, you should have a `DATA_DIR` folder with the following files:\n",
    "- labels_train.txt\n",
    "- labels_dev.txt\n",
    "- text_train.txt\n",
    "- text_dev.txt\n",
    "\n",
    "The format of the data is described in the NeMo docs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define Neural Modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate neural factory with supported backend\n",
    "nf = nemo.core.NeuralModuleFactory(\n",
    "\n",
    "    # If you're training with multiple GPUs, you should handle this value with\n",
    "    # something like argparse. See examples/nlp/token_classification.py for an example.\n",
    "    local_rank=None,\n",
    "\n",
    "    # If you're training with mixed precision, this should be set to \"O1\" or \"O2\".\n",
    "    # See https://nvidia.github.io/apex/amp.html#opt-levels for more details.\n",
    "    optimization_level=\"O1\",\n",
    "    \n",
    "    # Define path to the directory you want to store your results\n",
    "    log_dir=WORK_DIR,\n",
    "\n",
    "    # If you're training with multiple GPUs, this should be set to\n",
    "    # nemo.core.DeviceType.AllGpu\n",
    "    placement=nemo.core.DeviceType.GPU)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pre-trained encoder together with the tokenizer for the same model name.\n",
    "# This is the way to do it for a standard BERT model; the full list of\n",
    "# MegatronBERT/BERT/ALBERT/RoBERTa model names is returned by\n",
    "# nemo_nlp.nm.trainables.get_pretrained_lm_models_list()\n",
    "\n",
    "bert_model = nemo_nlp.nm.trainables.get_pretrained_lm_model(pretrained_model_name=PRETRAINED_BERT_MODEL)\n",
    "\n",
    "tokenizer = nemo_nlp.data.tokenizers.get_tokenizer(\n",
    "    tokenizer_name=\"nemobert\", pretrained_model_name=PRETRAINED_BERT_MODEL\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Describe training DAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data_layer = nemo_nlp.nm.data_layers.PunctuationCapitalizationDataLayer(\n",
    "     tokenizer=tokenizer,\n",
    "     text_file=os.path.join(DATA_DIR, 'text_train.txt'),\n",
    "     label_file=os.path.join(DATA_DIR, 'labels_train.txt'),\n",
    "     max_seq_length=MAX_SEQ_LENGTH,\n",
    "     batch_size=BATCH_SIZE)\n",
    "\n",
    "punct_label_ids = train_data_layer.dataset.punct_label_ids\n",
    "capit_label_ids = train_data_layer.dataset.capit_label_ids\n",
    "\n",
    "hidden_size = bert_model.hidden_size\n",
    "\n",
    "# Define classifier for Punctuation and Capitalization tasks\n",
    "classifier = PunctCapitTokenClassifier(\n",
    "            hidden_size=hidden_size,\n",
    "            punct_num_classes=len(punct_label_ids),\n",
    "            capit_num_classes=len(capit_label_ids),\n",
    "            dropout=0.1,\n",
    "            punct_num_layers=3,\n",
    "            capit_num_layers=2,\n",
    "        )\n",
    "\n",
    "# If you don't want to use weighted loss for Punctuation task, use class_weights=None\n",
    "punct_label_freqs = train_data_layer.dataset.punct_label_frequencies\n",
    "class_weights = calc_class_weights(punct_label_freqs)\n",
    "\n",
    "# define loss\n",
    "punct_loss = CrossEntropyLossNM(logits_ndim=3, weight=class_weights)\n",
    "capit_loss = CrossEntropyLossNM(logits_ndim=3)\n",
    "task_loss = LossAggregatorNM(num_inputs=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, punct_labels, capit_labels = train_data_layer()\n",
    "\n",
    "hidden_states = bert_model(\n",
    "    input_ids=input_ids,\n",
    "    token_type_ids=input_type_ids,\n",
    "    attention_mask=input_mask)\n",
    "\n",
    "punct_logits, capit_logits = classifier(hidden_states=hidden_states)\n",
    "\n",
    "punct_loss = punct_loss(\n",
    "    logits=punct_logits,\n",
    "    labels=punct_labels,\n",
    "    loss_mask=loss_mask)\n",
    "\n",
    "capit_loss = capit_loss(\n",
    "    logits=capit_logits,\n",
    "    labels=capit_labels,\n",
    "    loss_mask=loss_mask)\n",
    "\n",
    "task_loss = task_loss(\n",
    "    loss_1=punct_loss,\n",
    "    loss_2=capit_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Describe evaluation DAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note that you need to pass punct_label_ids and capit_label_ids (the mappings from labels to label ids generated\n",
    "# during creation of the train_data_layer) to make sure that the mapping is correct in case some of the labels from\n",
    "# the train set are missing in the dev set.\n",
    "\n",
    "eval_data_layer = nemo_nlp.nm.data_layers.PunctuationCapitalizationDataLayer(\n",
    "    tokenizer=tokenizer,\n",
    "    text_file=os.path.join(DATA_DIR, 'text_dev.txt'),\n",
    "    label_file=os.path.join(DATA_DIR, 'labels_dev.txt'),\n",
    "    max_seq_length=MAX_SEQ_LENGTH,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    punct_label_ids=punct_label_ids,\n",
    "    capit_label_ids=capit_label_ids,\n",
    ")\n",
    "\n",
    "# The fourth output is not needed for evaluation and is discarded\n",
    "(\n",
    "    eval_input_ids,\n",
    "    eval_input_type_ids,\n",
    "    eval_input_mask,\n",
    "    _,\n",
    "    eval_subtokens_mask,\n",
    "    eval_punct_labels,\n",
    "    eval_capit_labels,\n",
    ") = eval_data_layer()\n",
    "\n",
    "# Reuse the same encoder and classifier modules as the training graph\n",
    "hidden_states = bert_model(\n",
    "    input_ids=eval_input_ids, token_type_ids=eval_input_type_ids, attention_mask=eval_input_mask\n",
    ")\n",
    "\n",
    "eval_punct_logits, eval_capit_logits = classifier(hidden_states=hidden_states)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create callbacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _log_training_loss(tensors):\n",
    "    \"\"\"Log the value of the first tracked tensor, i.e. task_loss in the list below.\"\"\"\n",
    "    logging.info(\"Loss: {:.3f}\".format(tensors[0].item()))\n",
    "\n",
    "\n",
    "# Callback to report the training loss every STEP_FREQ steps\n",
    "callback_train = nemo.core.SimpleLossLoggerCallback(\n",
    "    tensors=[task_loss, punct_loss, capit_loss, punct_logits, capit_logits],\n",
    "    print_func=_log_training_loss,\n",
    "    step_freq=STEP_FREQ,\n",
    ")\n",
    "\n",
    "# If you're training on multiple GPUs, this should be\n",
    "# train_data_size / (batch_size * batches_per_step * num_gpus)\n",
    "train_data_size = len(train_data_layer)\n",
    "samples_per_step = BATCHES_PER_STEP * BATCH_SIZE\n",
    "steps_per_epoch = train_data_size // samples_per_step\n",
    "print('Number of steps per epoch: ', steps_per_epoch)\n",
    "\n",
    "# Callback to evaluate the model once every steps_per_epoch steps\n",
    "callback_eval = nemo.core.EvaluatorCallback(\n",
    "    eval_tensors=[\n",
    "        eval_punct_logits,\n",
    "        eval_capit_logits,\n",
    "        eval_punct_labels,\n",
    "        eval_capit_labels,\n",
    "        eval_subtokens_mask,\n",
    "    ],\n",
    "    user_iter_callback=eval_iter_callback,\n",
    "    user_epochs_done_callback=lambda global_vars: eval_epochs_done_callback(\n",
    "        global_vars, punct_label_ids, capit_label_ids\n",
    "    ),\n",
    "    eval_step=steps_per_epoch,\n",
    ")\n",
    "\n",
    "# Callback to store checkpoints every STEP_FREQ steps\n",
    "ckpt_callback = nemo.core.CheckpointCallback(folder=nf.checkpoint_dir, step_freq=STEP_FREQ)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The learning-rate schedule is defined over the total number of optimization steps\n",
    "total_steps = NUM_EPOCHS * steps_per_epoch\n",
    "lr_policy = WarmupAnnealing(total_steps, warmup_ratio=LR_WARMUP_PROPORTION)\n",
    "\n",
    "optimization_params = {\"num_epochs\": NUM_EPOCHS, \"lr\": LEARNING_RATE}\n",
    "\n",
    "nf.train(\n",
    "    tensors_to_optimize=[task_loss],\n",
    "    callbacks=[callback_train, callback_eval, ckpt_callback],\n",
    "    lr_policy=lr_policy,\n",
    "    batches_per_step=BATCHES_PER_STEP,\n",
    "    optimizer=OPTIMIZER,\n",
    "    optimization_params=optimization_params,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training for 10 epochs on this subset of the data should take about 20 minutes on a single V100 GPU.\n",
    "The model performance should be similar to the following:\n",
    "    \n",
    "                precision    recall  f1-score   support\n",
    "           O       1.00      0.99      0.99    137268\n",
    "           ,       0.58      0.95      0.72      2347\n",
    "           .       0.99      1.00      1.00     19078\n",
    "           ?       0.98      0.99      0.99      1151\n",
    "\n",
    "    accuracy                           0.99    159844\n",
    "    macro avg       0.89      0.98     0.92    159844\n",
    "    weighted avg    0.99      0.99     0.99    159844\n",
    "\n",
    "                precision    recall  f1-score   support\n",
    "           O       1.00      1.00      1.00    136244\n",
    "           U       1.00      0.99      0.99     23600\n",
    "\n",
    "    accuracy                           1.00    159844\n",
    "    macro avg       1.00      1.00     1.00    159844\n",
    "    weighted avg    1.00      1.00     1.00    159844"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the list of queries for inference\n",
    "queries = ['can i help you',\n",
    "           'yes please',\n",
    "           'we bought four shirts from the nvidia gear store in santa clara',\n",
    "           'we bought four shirts one mug and ten thousand titan rtx graphics cards',\n",
    "           'the more you buy the more you save']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a separate inference graph on top of the trained modules. The tensors carry an\n",
    "# infer_ prefix (like the eval_ prefix of the evaluation graph) so that they do not\n",
    "# overwrite the training tensors of the same name.\n",
    "infer_data_layer = nemo_nlp.nm.data_layers.BertTokenClassificationInferDataLayer(\n",
    "    queries=queries,\n",
    "    tokenizer=tokenizer,\n",
    "    max_seq_length=MAX_SEQ_LENGTH,\n",
    "    batch_size=1)\n",
    "\n",
    "infer_input_ids, infer_input_type_ids, infer_input_mask, _, infer_subtokens_mask = infer_data_layer()\n",
    "\n",
    "infer_hidden_states = bert_model(\n",
    "    input_ids=infer_input_ids,\n",
    "    token_type_ids=infer_input_type_ids,\n",
    "    attention_mask=infer_input_mask)\n",
    "\n",
    "infer_punct_logits, infer_capit_logits = classifier(hidden_states=infer_hidden_states)\n",
    "\n",
    "# Restore the weights from the same directory the checkpoint callback writes to\n",
    "# (nf.checkpoint_dir) instead of rebuilding that path by hand.\n",
    "evaluated_tensors = nf.infer(\n",
    "    tensors=[infer_punct_logits, infer_capit_logits, infer_subtokens_mask],\n",
    "    checkpoint_dir=nf.checkpoint_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# helper functions\n",
    "def concatenate(lists):\n",
    "    \"\"\"Move every tensor in `lists` to the CPU and join them into a single numpy array.\"\"\"\n",
    "    cpu_tensors = [t.cpu() for t in lists]\n",
    "    return np.concatenate(cpu_tensors)\n",
    "\n",
    "\n",
    "# Invert the label -> id mappings so that predicted ids can be turned back into labels\n",
    "punct_ids_to_labels = {label_id: label for label, label_id in punct_label_ids.items()}\n",
    "capit_ids_to_labels = {label_id: label for label, label_id in capit_label_ids.items()}\n",
    "\n",
    "punct_logits, capit_logits, subtokens_mask = map(concatenate, evaluated_tensors)\n",
    "punct_preds = punct_logits.argmax(axis=2)\n",
    "capit_preds = capit_logits.argmax(axis=2)\n",
    "\n",
    "for idx, query in enumerate(queries):\n",
    "    print(f'Query: {query}')\n",
    "\n",
    "    # Keep only the positions selected by the subtokens mask\n",
    "    # (presumably one position per word -- the length check below enforces this)\n",
    "    selected = subtokens_mask[idx] > 0.5\n",
    "    punct_pred = punct_preds[idx][selected]\n",
    "    capit_pred = capit_preds[idx][selected]\n",
    "\n",
    "    words = query.strip().split()\n",
    "    if not (len(punct_pred) == len(words) == len(capit_pred)):\n",
    "        raise ValueError('Pred and words must be of the same length')\n",
    "\n",
    "    pieces = []\n",
    "    for word, punct_id, capit_id in zip(words, punct_pred, capit_pred):\n",
    "        punct_label = punct_ids_to_labels[punct_id]\n",
    "        capit_label = capit_ids_to_labels[capit_id]\n",
    "\n",
    "        # 'O' is the label for \"leave the word unchanged\" in both tasks\n",
    "        if capit_label != 'O':\n",
    "            word = word.capitalize()\n",
    "        if punct_label != 'O':\n",
    "            word += punct_label\n",
    "        pieces.append(word)\n",
    "\n",
    "    combined = ' '.join(pieces).strip()\n",
    "    print(f'Combined: {combined}\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The inference output should look something like this:<br>\n",
    "\n",
    "Query: can i help you<br>\n",
    "Combined: Can I help you?<br>\n",
    "\n",
    "Query: yes please<br>\n",
    "Combined: Yes, please.<br>\n",
    "\n",
    "Query: we bought four shirts from the nvidia gear store in santa clara<br>\n",
    "Combined: We bought four shirts from the Nvidia gear store in Santa Clara.<br>\n",
    "            \n",
    "Query: we bought four shirts one mug and ten thousand titan rtx graphics cards<br>\n",
    "Combined: We bought four shirts, one mug, and ten thousand Titan Rtx graphics cards.<br>\n",
    "\n",
    "Query: the more you buy the more you save<br>\n",
    "Combined: The more you buy, the more you save.<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Set NUM_SAMPLES=-1 and consider including other datasets to improve the performance of the model.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
